import jax
import jax.numpy as np
import scalevi.distributions.distributions as dists
import scalevi.distributions.scale_transforms as scale_transforms
import scalevi.utils.utils as utils
import scalevi.models as models
import scalevi.models.models_base as models_base
import inspect

# Names (sorted, as returned by ``inspect.getmembers``) of every class found in
# ``models.models_branched`` that subclasses ``ModelBranched`` or ``Model``.
# Non-class members are skipped before ``issubclass`` is evaluated, since
# ``issubclass`` raises TypeError on non-class arguments.
MODELS = [
    member_name
    for member_name, member in inspect.getmembers(models.models_branched)
    if inspect.isclass(member)
    and issubclass(member, (models_base.ModelBranched, models_base.Model))
]

def _branch_gaussian_data(D_par, D_kid, N_leaves):
    """Sample the parameters used as data for the "BranchGaussian" model.

    All draws come from the fixed key ``jax.random.PRNGKey(0)``, so the
    result is deterministic. The key is split once per quantity, in the order
    below. Reordering these steps changes every sampled value.

    Args:
        D_par: dimension of the parent variable.
        D_kid: dimension of each leaf ("kid") variable.
        N_leaves: number of leaves.

    Returns:
        dict with keys ``'μ_par'``, ``'L_par'``, ``'μ_kid'`` and ``'L_kid'``.
        The means are standard-normal draws of shape ``(D_par,)`` and
        ``(N_leaves, D_kid)``. Each scale is ``5 * I`` plus
        ``ProximalScaleTransform(10.0).forward`` applied to ``0.1`` times a
        standard-normal draw of shape ``(D_par, D_par)`` /
        ``(N_leaves, D_kid, D_kid)``.
    """
    scale_tf = scale_transforms.ProximalScaleTransform(10.0).forward
    data = {}
    rng_key = jax.random.PRNGKey(0)

    rng_key, rng_subkey = jax.random.split(rng_key)
    data['μ_par'] = jax.random.normal(rng_subkey, (D_par,))

    rng_key, rng_subkey = jax.random.split(rng_key)
    data['L_par'] = 5*np.eye(D_par) + scale_tf(0.1*jax.random.normal(
                            rng_subkey, (D_par, D_par)))

    rng_key, rng_subkey = jax.random.split(rng_key)
    data['μ_kid'] = jax.random.normal(rng_subkey, (N_leaves, D_kid))

    # Last split: the carried key is no longer needed.
    _, rng_subkey = jax.random.split(rng_key)
    # The (D_kid, D_kid) identity broadcasts against the per-leaf batch.
    data['L_kid'] = 5*np.eye(D_kid) + scale_tf(0.1*jax.random.normal(
                            rng_subkey, (N_leaves, D_kid, D_kid)))
    return data


def get_data(config_dict, nb):
    """Build the synthetic dataset for the model named in ``config_dict``.

    Args:
        config_dict: configuration mapping. ``config_dict['model']`` selects
            the generator:

            * ``"BranchGaussian"``: reads ``'D_par'``, ``'D_kid'`` and
              ``'N_leaves'`` with ``dict.get``, so a missing key yields
              None. It then samples Gaussian parameters with a fixed PRNG key.
            * any other name containing ``"Branch"``: reads ``'D_data'``,
              ``'N_leaves'`` and ``'M_data'``, plus optional ``'use_mask'``
              (default False). Sampling is delegated to the model class's
              ``_forward_sample`` using ``jax.random.PRNGKey(0)``.
        nb: currently unused by this function.

    Returns:
        dict of generated data. For the generic "Branch" case it also holds
        ``'D_par'``, ``'D_kid'`` and ``'use_mask'``, updated with whatever
        ``_forward_sample`` returns.

    Raises:
        KeyError: if ``config_dict`` has no ``'model'`` entry.
        NotImplementedError: if the model name matches neither case above.
    """
    model_name = config_dict['model']

    if model_name == "BranchGaussian":
        D_par, D_kid, N_leaves = (
            config_dict.get(key) for key in ('D_par', 'D_kid', 'N_leaves'))
        return _branch_gaussian_data(D_par, D_kid, N_leaves)

    if "Branch" in model_name:
        # NOTE(review): assumes utils.dict_get returns the values in the
        # order of the requested keys -- confirm against scalevi.utils.
        D_data, N_leaves, M_data = utils.dict_get(
            config_dict, ['D_data', 'N_leaves', "M_data"])
        data = {
            'D_par': D_data,
            'D_kid': D_data,
            'use_mask': config_dict.get("use_mask", False),
        }
        rng_key = jax.random.PRNGKey(0)
        data.update(models.get_model_class(config_dict)._forward_sample(
            rng_key,
            M_data,
            N_leaves,
            D_data,
            data['use_mask']))
        return data

    # Interpolate MODELS itself; the previous `{[MODELS]}` printed a
    # needlessly nested list such as [['A', 'B']].
    raise NotImplementedError(
        f"Could not find the dataloader. "
        f"Expected one of {MODELS}"
        f" but got {model_name}")

